#############################################################################
# Function that returns coefficients per county and mg per county-year
#############################################################################

# Function that takes a data set, re-estimates the model with heterogeneous
# slopes, and returns the county slope (mg) attached to every county-year row

heterog_sub <- function(data) {

  # Fixed-effects regression of yield on precr plus one ddr slope per county
  # (ddr interacted with the fips factor). Fixed effects: fips and the
  # year-by-state combination.
  fit <- feols(yield ~ precr + ddr:factor(fips) | factor(fips) + factor(year)^factor(state), data = data)

  # County slopes collected by the project helper slopes_dt()
  # NOTE(review): the result is expected to carry `fips` and `coef` columns,
  # since both are used below - confirm against slopes_dt().
  county_slopes <- slopes_dt(f = fit, names_var = "")

  # Attach the county-year rows of the input to each county slope
  # (left join on fips, so every slope row is kept)
  panel_cols <- c("fips", "year", "ddr", "corn_area")
  county_slopes <- merge(
    county_slopes,
    data[, .SD, .SDcols = panel_cols],
    by = "fips",
    all.x = TRUE
  )

  # The marginal effect is the county slope itself; keep only the
  # identifiers, the area weight and the marginal effect
  county_slopes[, .(fips, year, corn_area, mg = coef)]

}


#############################################################################
# Function for jackknife 
#############################################################################

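# Arguments:
# - data: county-year data
# - trim: proportion of counties with the lowest sd(ddr) to be trimmed
# - sel: how to split the counties (currently not used by the function body)
# - sseed: base random seed; county split k is drawn under seed sseed + k
# - nsplit: number of random county splits

# Returns:
# - R: descriptive table of mg statistics (the jackknife formula is applied to
#   each statistic row, hence it is not derived from the CDF)
# - dt_jack: data at the county-year level with the marginal effects and slopes
# - nobs: number of rows of dt_jack (for the cases where there was trimming)
# - nc: number of distinct counties in dt_jack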

jack2_split <- function(data,trim=0,sel=0,sseed=1,nsplit=1) {

  # Split-sample jackknife of the county marginal effects.
  #
  # Arguments:
  # - data:   data.table handed to heterog_sub(); this function additionally
  #           reads the columns nobs, ddr_sd, state, fips, year and corn_area.
  # - trim:   if > 0, rows whose ddr_sd is at or below the `trim` quantile of
  #           ddr_sd (quantile taken over rows with nobs==1) are dropped.
  # - sel:    not referenced in the body; kept so existing calls keep working.
  # - sseed:  base random seed; county split k is drawn under seed sseed+k.
  # - nsplit: number of random county splits (must be >= 1).
  #
  # Returns a list with:
  # - R:       matrix with weighted percentiles (10/25/50/75/90), mean and
  #            variance of mg for every sample, the averages over county
  #            splits, and the jackknife combination `corr`.
  # - dt_jack: county-year table with mg from every sample.
  # - nobs:    number of rows of dt_jack.
  # - nc:      number of distinct fips in dt_jack.

  # nsplit < 1 already failed in the original (inside seq()), only with an
  # obscure message; fail up front instead.
  stopifnot(length(nsplit) == 1, nsplit >= 1)

  set.seed(sseed)

  # Trim data
  # NOTE(review): nobs==1 presumably flags one row per county - confirm.
  dt <- copy(data)
  if (trim>0) {
    qq <- as.matrix(dt[nobs==1,lapply(.SD,quantile,probs=trim),.SDcols="ddr_sd"])
    dt <- dt[ddr_sd>qq[1]]
  }

  # Overall
  dt_s0 <- copy(dt)
  het_s0 <- heterog_sub(data=dt_s0)
  setnames(het_s0,"mg","mg0")

  # Sample 1a: all counties, first half of periods
  year_list <- sort(unique(dt$year))
  l <- length(year_list)
  dt_s1 <- dt_s0[year %in% year_list[1:floor(l/2)]]
  het_s1a <- heterog_sub(data=dt_s1)
  setnames(het_s1a,"mg","mg1a")
  het_s1a[,corn_area:=NULL]

  # Sample 1b: all counties, second half of periods
  dt_s1 <- dt_s0[!(year %in% year_list[1:floor(l/2)])]
  het_s1b <- heterog_sub(data=dt_s1)
  setnames(het_s1b,"mg","mg1b")
  het_s1b[,corn_area:=NULL]

  # Merge (so far); full outer joins, so rows absent from a half-sample get NA
  dt_jack <- merge(het_s0,het_s1a,by=c("fips","year"),all=TRUE)
  dt_jack <- merge(dt_jack,het_s1b,by=c("fips","year"),all=TRUE)

  # Total corn_area per state (denominator of the cumulative weights below)
  dt_s0[,area_state:=lapply(.SD,sum),.SDcols="corn_area",by="state"]

  # Split counties
  for (k in seq(1,nsplit,1)) {

    set.seed(sseed+k)

    # Clear the assignment left by the previous split. `sample` is the only
    # helper column that survives an iteration; without this reset no row is
    # NA from the second split onwards, so counties straddling the 50%
    # threshold silently keep their previous value instead of being drawn.
    dt_s0[,sample:=NA_real_]

    # Sort counties randomly within state: the draw on the nobs==1 row is
    # propagated to every row of the county (all other rows are zeroed first)
    dt_s0[,r:=runif(nrow(dt_s0))*(nobs==1)]
    dt_s0[,r:=lapply(.SD,max),.SDcols="r",by="fips"]
    setkey(dt_s0,state,r,fips,year)

    # Select, per state, the first counties for half the weight
    dt_s0[,aux:=corn_area/area_state]
    dt_s0[,aux:=cumsum(aux),by="state"]
    dt_s0[,max_aux:=lapply(.SD,max),.SDcols="aux",by="fips"]
    dt_s0[,min_aux:=lapply(.SD,min),.SDcols="aux",by="fips"]
    dt_s0[max_aux<0.5,sample:=1]
    dt_s0[min_aux>=0.5,sample:=2]
    # For those counties in between, choose randomly
    dt_s0[,ssum:=lapply(.SD,max),.SDcols="max_aux",by=c("state","sample")]
    dt_s0[,ssum_lag:=shift(ssum,n=1,type="lag")]
    dt_s0[,alea:=runif(n=nrow(dt_s0))*(ssum-ssum_lag)+ssum_lag]
    dt_s0[,addc:=as.numeric(is.na(sample)==TRUE & alea<0.5) ]
    dt_s0[,addc:=lapply(.SD,max),.SDcols="addc",by="fips"]
    dt_s0[addc==1 & is.na(sample)==TRUE,sample:=1]
    dt_s0[addc==0 & is.na(sample)==TRUE,sample:=2]
    dt_s0[,c("r","aux","max_aux","min_aux","ssum","ssum_lag","alea","addc"):=NULL]

    # Sample 2a: first half of counties, all periods
    dt_s2 <- dt_s0[sample==1]
    het_s2a <- heterog_sub(data=dt_s2)
    setnames(het_s2a,"mg",paste0("mg2a_s",k))
    het_s2a[,corn_area:=NULL]

    # Sample 2b: second half of counties, all periods
    dt_s2 <- dt_s0[sample==2]
    het_s2b <- heterog_sub(data=dt_s2)
    setnames(het_s2b,"mg",paste0("mg2b_s",k))
    het_s2b[,corn_area:=NULL]

    # Merge
    dt_jack <- merge(dt_jack,het_s2a,by=c("fips","year"),all=TRUE)
    dt_jack <- merge(dt_jack,het_s2b,by=c("fips","year"),all=TRUE)

  }

  # Tabulate percentiles, mean and variance of mg effects (corn_area weights):
  # rows 1-5 are the percentiles, row 6 the mean, row 7 the variance
  nnames <- c("mg0","mg1a","mg1b",paste0("mg2a_s",1:nsplit),paste0("mg2b_s",1:nsplit))
  R <- dt_jack[,lapply(.SD,wtd.quantile,probs=c(0.1,0.25,0.5,0.75,0.9),weights=corn_area,na.rm=TRUE),.SDcols=nnames]
  R <- rbind(R,dt_jack[,lapply(.SD,wtd.mean,weights=corn_area,na.rm=TRUE),.SDcols=nnames])
  R <- rbind(R,dt_jack[,lapply(.SD,wtd.var,weights=corn_area,na.rm=TRUE),.SDcols=nnames])

  # Average across county-splits
  R[,mg2a:=0]
  R[,mg2b:=0]
  for (k in seq(1,nsplit,1)) {
    R[,mg2a:=mg2a+.SD,.SDcols=paste0("mg2a_s",k)]
    R[,mg2b:=mg2b+.SD,.SDcols=paste0("mg2b_s",k)]
  }
  R[,mg2a:=mg2a/nsplit]
  R[,mg2b:=mg2b/nsplit]

  # Generate jackknife version (applied to each statistic row)
  R[,corr:=3*mg0-(mg1a+mg1b)/2-(mg2a+mg2b)/2]
  R <- R[,.SD,.SDcols=c("mg0","mg1a","mg1b",paste0("mg2a_s",1:nsplit),"mg2a",paste0("mg2b_s",1:nsplit),"mg2b","corr")]
  R <- as.matrix(R)
  colnames(R) <- c("tot","samp_year1","samp_year2",paste0("sample_count1_s",1:nsplit),"Avg_count1",
                   paste0("sample_count2_s",1:nsplit),"Avg_count2","corr")
  rownames(R) <- c(paste0("Percentil ",c(10,25,50,75,90)),"Mean","Var")

  return(list(R=R,dt_jack=dt_jack,nobs=nrow(dt_jack),nc=length(unique(dt_jack$fips))))
}


#############################################################################
# Jackknife CDF
#############################################################################

jack2_cdf <- function(data=dt_plot,npoints=10001,bval=45,nsplit=1) {

  # Weighted empirical CDFs of the marginal effects on a fixed grid, their
  # jackknife combination, and a monotone (rearranged) and [0, 1]-truncated
  # version of that combination.
  #
  # Arguments:
  # - data:    table with the columns mg_fe, mg1a, mg1b, mg2a_s<k>, mg2b_s<k>
  #            (k = 1..nsplit) and the weight column corn_area.
  # - npoints: number of grid points.
  # - bval:    the grid runs from -bval to bval.
  # - nsplit:  number of county splits present in `data`.
  #
  # Returns a data.table with one row per grid point and the columns
  # mg_fe, mg_s1a, mg_s1b, mg_s2a, mg_s2b, b, cdf_jk, cdf_mon, cdf_mont.

  # Private copy: indicator columns are added to it by reference below
  panel <- copy(data)

  # Grid on which every CDF is evaluated
  b_grid <- seq(-bval, bval, length.out = npoints)

  split_ids <- 1:nsplit
  mg_cols <- c("mg_fe", "mg1a", "mg1b",
               paste0("mg2a_s", split_ids), paste0("mg2b_s", split_ids))
  ind_cols <- c("aux_fe", "aux_1a", "aux_1b",
                paste0("aux_2as", split_ids), paste0("aux_2bs", split_ids))
  n_mg <- length(mg_cols)

  # One row per grid point, one column per sample
  cdf_mat <- matrix(NA, nrow = npoints, ncol = 3 + 2 * nsplit)
  for (p in seq(1, npoints, 1)) {

    # Indicator of being at or below the current grid point; its weighted
    # mean is the CDF value at that point
    at_or_below <- function(v) v <= b_grid[p]
    panel[, (ind_cols) := lapply(.SD, at_or_below), .SDcols = mg_cols]
    cdf_mat[p, 1:n_mg] <- as.matrix(
      panel[, lapply(.SD, wtd.mean, weights = corn_area), .SDcols = ind_cols]
    )

  }

  # Turn the matrix into a table and label it
  cdf_tab <- as.data.table(cdf_mat)
  cdf_tab[, b := b_grid]
  setnames(cdf_tab, paste0("V", 1:n_mg), mg_cols)

  # Average the CDFs over the county splits
  acc_a <- rep(0, nrow(cdf_tab))
  acc_b <- rep(0, nrow(cdf_tab))
  for (s in seq(1, nsplit, 1)) {
    acc_a <- acc_a + cdf_tab[[paste0("mg2a_s", s)]]
    acc_b <- acc_b + cdf_tab[[paste0("mg2b_s", s)]]
  }
  cdf_tab[, mg2a := acc_a / nsplit]
  cdf_tab[, mg2b := acc_b / nsplit]

  # Keep the full-sample, half-sample and split-averaged CDFs plus the grid,
  # under their output names
  keep_cols <- c("mg_fe", "mg1a", "mg1b", "mg2a", "mg2b", "b")
  cdf_tab <- cdf_tab[, keep_cols, with = FALSE]
  setnames(cdf_tab, c("mg_fe", "mg_s1a", "mg_s1b", "mg_s2a", "mg_s2b", "b"))

  # Jackknife-corrected CDF
  cdf_tab[, cdf_jk := {
    year_avg <- (mg_s1a + mg_s1b) / 2
    county_avg <- (mg_s2a + mg_s2b) / 2
    3 * mg_fe - year_avg - county_avg
  }]

  # The corrected curve need not be monotone: build a step function from it
  # and rearrange it over the range of the grid
  raw_step <- stepfun(x = cdf_tab$b, y = c(0, cdf_tab$cdf_jk), f = 0, ties = "ordered", right = FALSE)
  mono_step <- rearrange(f = raw_step, xmin = min(cdf_tab$b), xmax = max(cdf_tab$b))
  cdf_tab[, cdf_mon := mono_step(b)]

  # Truncated copy of the rearranged CDF, forced into [0, 1]
  cdf_tab[, cdf_mont := cdf_mon]
  cdf_tab[cdf_mon < 0, cdf_mont := 0]
  cdf_tab[cdf_mon > 1, cdf_mont := 1]

  cdf_tab
}

#############################################################################
# Jackknife simulated data to get implicit PDF
#############################################################################

# Simulate data from the CDF to get implicit PDF
jack2_sim <- function(cdf_data,Nsim=10000,sseed=1) {

  # Draws Nsim uniform numbers and inverts each CDF column of cdf_data at
  # every draw: the simulated value is the largest grid point b whose CDF
  # value is <= the draw, or min(b) when no grid point qualifies.
  #
  # Arguments:
  # - cdf_data: table with the grid column b and the CDF columns mg_fe,
  #             mg_s1a, mg_s1b, mg_s2a, mg_s2b, cdf_jk, cdf_mon, cdf_mont.
  # - Nsim:     number of draws (>= 1).
  # - sseed:    random seed.
  #
  # Returns a data.table with one row per draw: the eight simulated values
  # (named after their CDF column) and the draw itself in column r.

  cdf_cols <- c("mg_fe","mg_s1a","mg_s1b","mg_s2a","mg_s2b","cdf_jk","cdf_mon","cdf_mont")

  # Nsim < 1 and missing columns were already errors (inside seq() and the
  # table subset respectively); check them explicitly up front.
  stopifnot(
    length(Nsim) == 1,
    Nsim >= 1,
    all(c(cdf_cols, "b") %in% names(cdf_data))
  )

  set.seed(sseed)

  # Pull the columns out once instead of subsetting the table for every draw
  b_grid <- cdf_data[["b"]]
  min_b <- min(b_grid)
  cdf_vals <- lapply(cdf_cols, function(nm) cdf_data[[nm]])

  # One vectorised call yields the same stream as repeated runif(1) calls
  n_draws <- length(seq(1, Nsim, 1))
  draws <- runif(n = n_draws)

  # Preallocate the result instead of growing it with rbind() on every draw
  sim_mat <- matrix(NA_real_, nrow = n_draws, ncol = length(cdf_cols) + 1)
  for (s in seq_len(n_draws)) {
    u <- draws[s]
    # which() drops NA comparisons, as the table subset did
    inv <- vapply(
      cdf_vals,
      function(v) max(b_grid[which(v <= u)], min_b),
      numeric(1)
    )
    sim_mat[s, ] <- c(inv, u)
  }

  pdf_sim <- data.table(sim_mat)
  setnames(pdf_sim, c(cdf_cols, "r"))

  return(pdf_sim)

}
